{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Copyright 2025 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## **Text Classification Eval Recipe**\n",
        "\n",
        "This Eval Recipe demonstrates how to compare performance of two models on a text classification prompt using [Vertex AI Evaluation Service](https://cloud.google.com/vertex-ai/generative-ai/docs/models/evaluation-overview)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-text-classification-eval-recipe&utm_medium=aRT-clicks&utm_campaign=text-classification-eval-recipe&destination=text-classification-eval-recipe&url=https%3A%2F%2Fcolab.research.google.com%2Fgithub%2FGoogleCloudPlatform%2Fapplied-ai-engineering-samples%2Fblob%2Fmain%2Fgenai-on-vertex-ai%2Fgemini%2Fmodel_upgrades%2Ftext_classification%2Fvertex_colab%2Ftext_classification_eval.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-text-classification-eval-recipe&utm_medium=aRT-clicks&utm_campaign=text-classification-eval-recipe&destination=text-classification-eval-recipe&url=https%3A%2F%2Fconsole.cloud.google.com%2Fvertex-ai%2Fcolab%2Fimport%2Fhttps%3A%252F%252Fraw.githubusercontent.com%252FGoogleCloudPlatform%252Fapplied-ai-engineering-samples%252Fmain%252Fgenai-on-vertex-ai%252Fgemini%252Fmodel_upgrades%252Ftext_classification%252Fvertex_colab%252Ftext_classification_eval.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://art-analytics.appspot.com/r.html?uaid=G-FHXEFWTT4E&utm_source=aRT-text-classification-eval-recipe&utm_medium=aRT-clicks&utm_campaign=text-classification-eval-recipe&destination=text-classification-eval-recipe&url=https%3A%2F%2Fconsole.cloud.google.com%2Fvertex-ai%2Fworkbench%2Fdeploy-notebook%3Fdownload_url%3Dhttps%3A%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fapplied-ai-engineering-samples%2Fmain%2Fgenai-on-vertex-ai%2Fgemini%2Fmodel_upgrades%2Ftext_classification%2Fvertex_colab%2Ftext_classification_eval.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/applied-ai-engineering-samples/blob/main/genai-on-vertex-ai/gemini/model_upgrades/text_classification/vertex_colab/text_classification_eval.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YYtBxd6y-Ju2"
      },
      "source": [
        "- Use case: given a Product Description find the most relevant Product Category from a predefined list of categories.\n",
        "\n",
        "- Metric: this eval uses a single deterministic metric \"Accuracy\" calculated by comparing model responses with ground truth labels. \n",
        "\n",
        "- Labeled evaluation dataset [dataset.jsonl](https://storage.googleapis.com/gemini_assets/classification_vertex/dataset.jsonl) is based [MAVE](https://github.com/google-research-datasets/MAVE/tree/main) dataset from Google Research. It includes 6 records that represent products from different categories. Each record includes two attributes:\n",
        "    - `product`: product name and description\n",
        "    - `reference`: correct product category name which serves as the ground truth label\n",
        "\n",
        "- Prompt template is a zero-shot prompt located in [prompt_template.txt](https://storage.googleapis.com/gemini_assets/classification_vertex/prompt_template.txt) with just one prompt variable `product` that maps to the `product` attribute in the dataset.\n",
        "\n",
        "Step 1 of 4: Configure eval settings\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZZnEA6GZ-kMW"
      },
      "outputs": [],
      "source": [
        "%%writefile .env\n",
        "PROJECT_ID=your-project-id            # Google Cloud Project ID\n",
        "LOCATION=us-central1                  # Region for all required Google Cloud services\n",
        "EXPERIMENT_NAME=eval-classification   # Creates Vertex AI Experiment to track the eval runs\n",
        "MODEL_BASELINE=gemini-1.5-flash-002   # Name of your current model\n",
        "MODEL_CANDIDATE=gemini-2.0-flash-001  # This model will be compared to the baseline model\n",
        "DATASET_URI=\"gs://gemini_assets/classification_vertex/dataset.jsonl\"  # Evaluation dataset in Google Cloud Storage\n",
        "PROMPT_TEMPLATE_URI=gs://gemini_assets/classification_vertex/prompt_template.txt  # Text file in Google Cloud Storage"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aqeOCM7k9t5h"
      },
      "source": [
        "Step 2 of 4: Install Python libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": true,
        "id": "bR0rvHA3Lby6"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet google-cloud-aiplatform[evaluation] python-dotenv\n",
        "# The error \"session crashed\" is expected. Please ignore it and proceed to the next cell.\n",
        "import IPython\n",
        "IPython.Application.instance().kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9OUfM5Lz9_aM"
      },
      "source": [
        "Step 3 of 4: Authenticate and initialize Vertex AI"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VRFZFC6OLby7"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import vertexai\n",
        "from dotenv import load_dotenv\n",
        "from google.cloud import storage\n",
        "\n",
        "load_dotenv(override=True)\n",
        "if os.getenv(\"PROJECT_ID\") == \"your-project-id\":\n",
        "    raise ValueError(\"Please configure your Google Cloud Project ID in the first cell.\")\n",
        "if \"google.colab\" in sys.modules:  \n",
        "    from google.colab import auth  \n",
        "    auth.authenticate_user()\n",
        "vertexai.init(project=os.getenv('PROJECT_ID'), location=os.getenv('LOCATION'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CQvekvLt9SWD"
      },
      "source": [
        "Step 4 of 4: Run the eval on both models and compare the Accuracy scores"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KfG9JG9VHaNw"
      },
      "outputs": [],
      "source": [
        "from datetime import datetime\n",
        "from IPython.display import clear_output\n",
        "from vertexai.evaluation import EvalTask, EvalResult, CustomMetric, MetricPromptTemplateExamples\n",
        "from vertexai.generative_models import GenerativeModel, HarmBlockThreshold, HarmCategory\n",
        "\n",
        "def case_insensitive_match(record: dict[str, str]) -> dict[str, float]:\n",
        "    response = record[\"response\"].strip().lower()\n",
        "    label = record[\"reference\"].strip().lower()\n",
        "    return {\"accuracy\": 1.0 if label == response else 0.0}\n",
        "\n",
        "def load_prompt_template() -> str:\n",
        "    blob = storage.Blob.from_string(os.getenv(\"PROMPT_TEMPLATE_URI\"), storage.Client())\n",
        "    return blob.download_as_string().decode('utf-8')\n",
        "\n",
        "def run_eval(model: str) -> EvalResult:\n",
        "  timestamp = f\"{datetime.now().strftime('%b-%d-%H-%M-%S')}\".lower()\n",
        "  return EvalTask(\n",
        "      dataset=os.getenv(\"DATASET_URI\"),\n",
        "      metrics=[CustomMetric(name=\"accuracy\", metric_function=case_insensitive_match)],\n",
        "      experiment=os.getenv('EXPERIMENT_NAME')\n",
        "  ).evaluate(\n",
        "      model=GenerativeModel(model),\n",
        "      prompt_template=load_prompt_template(),\n",
        "      experiment_run_name=f\"{timestamp}-{model.replace('.', '-')}\"\n",
        "  )\n",
        "\n",
        "baseline = run_eval(os.getenv(\"MODEL_BASELINE\"))\n",
        "candidate = run_eval(os.getenv(\"MODEL_CANDIDATE\"))\n",
        "clear_output()\n",
        "print(\"Baseline model accuracy:\", baseline.summary_metrics[\"accuracy/mean\"])\n",
        "print(\"Candidate model accuracy:\", candidate.summary_metrics[\"accuracy/mean\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EhZZ030LJGL3"
      },
      "source": [
        "This Eval Recipe is intended to be a simple starting point. Please use our [documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval) to learn about all available metrics and customization options."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "3.10.12",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
